{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Color FID Benchmark (HQ)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# Expose only CUDA device 1 to this process and limit OpenMP to one thread.\n",
    "# NOTE(review): presumably these must be set before torch/fastai are imported in the next cell to take effect - confirm.\n",
    "os.environ['CUDA_VISIBLE_DEVICES']='1'\n",
    "os.environ['OMP_NUM_THREADS']='1'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import statistics\n",
    "from fastai import *\n",
    "from deoldify.visualize import *\n",
    "import cv2\n",
    "from fid.fid_score import *\n",
    "from fid.inception import *\n",
    "import imageio\n",
    "plt.style.use('dark_background')\n",
    "torch.backends.cudnn.benchmark=True\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\", category=UserWarning, module=\"torch.nn.functional\")\n",
    "warnings.filterwarnings(\"ignore\", category=UserWarning, message='.*?retrieve source code for container of type.*?')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#NOTE:  Data should come from here:  'https://datasets.figure-eight.com/figure_eight_datasets/open-images/test_challenge.zip'\n",
    "#NOTE:  Minimum recommended number of samples is 10K.  Source:  https://github.com/bioinf-jku/TTUR\n",
    "\n",
    "# Dataset layout: full-color originals are read from path_hr; grayscale copies are written to path_lr.\n",
    "path = Path('data/ColorBenchmark')\n",
    "path_hr = path/'source'\n",
    "path_lr = path/'bandw'\n",
    "# Colorized outputs are saved under path_results by render_images.\n",
    "path_results = Path('./result_images/ColorBenchmarkFID/artistic')\n",
    "# NOTE(review): path_rendered is not referenced anywhere else in this notebook.\n",
    "path_rendered = path_results/'rendered'\n",
    "\n",
    "# Alternative dataset configuration (disabled).\n",
    "#path = Path('data/DeOldifyColor')\n",
    "#path_hr = path\n",
    "#path_lr = path/'bandw'\n",
    "#path_results = Path('./result_images/ColorBenchmark/edge')\n",
    "#path_rendered = path_results/'rendered'\n",
    "\n",
    "# Upper bound on how many grayscale images are rendered and scored.\n",
    "#num_images = 2048\n",
    "#num_images = 15000\n",
    "num_images = 50000\n",
    "# Forwarded to colorizer.get_transformed_image for every image.\n",
    "render_factor=35\n",
    "# Batch size used when computing the activation statistics for FID.\n",
    "fid_batch_size = 4\n",
    "# NOTE(review): passed to calculate_fid_score, which currently does not use it.\n",
    "eval_size=299"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def inception_model(dims:int):\n",
    "    \"\"\"Build an InceptionV3 model for the requested activation size and move it to the GPU.\n",
    "\n",
    "    `dims` is looked up in `InceptionV3.BLOCK_INDEX_BY_DIM` to select the single\n",
    "    output block the model is constructed with.\n",
    "    \"\"\"\n",
    "    net = InceptionV3([InceptionV3.BLOCK_INDEX_BY_DIM[dims]])\n",
    "    net.cuda()\n",
    "    return net"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_before_images(fn,i):\n",
    "    \"\"\"Save a grayscale copy of the color image `fn` under `path_lr`.\n",
    "\n",
    "    The copy keeps `fn`'s location relative to `path_hr`. The image is converted\n",
    "    to 'LA' and then back to 'RGB' before saving.\n",
    "    `i` is unused. NOTE(review): presumably the index supplied by `parallel` - confirm.\n",
    "    \"\"\"\n",
    "    bw_path = path_lr/fn.relative_to(path_hr)\n",
    "    bw_path.parent.mkdir(parents=True, exist_ok=True)\n",
    "    gray = PIL.Image.open(fn).convert('LA')\n",
    "    gray.convert('RGB').save(bw_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def render_images(colorizer, source_dir:Path, filtered_dir:Path, target_dir:Path, render_factor:int, num_images:int)->[(Path, Path, Path)]:\n",
    "    \"\"\"Colorize up to `num_images` images from `source_dir` and save them under `filtered_dir`.\n",
    "\n",
    "    Args:\n",
    "        colorizer: object exposing `get_transformed_image(path=..., render_factor=...)`.\n",
    "        source_dir: folder scanned for the black-and-white input images.\n",
    "        filtered_dir: root folder for the colorized output; each image is written to\n",
    "            filtered_dir / (input's parent folder name) / (input's file name).\n",
    "        target_dir: folder holding the originals; each input's path relative to\n",
    "            `source_dir` is mirrored under it to name the matching target image.\n",
    "        render_factor: forwarded unchanged to the colorizer.\n",
    "        num_images: maximum number of inputs to process.\n",
    "\n",
    "    Returns:\n",
    "        A list of `(result_path, bandw_path, target_path)` tuples, one per image that\n",
    "        rendered successfully. Images that fail to render are reported and skipped.\n",
    "    \"\"\"\n",
    "    # Read from the directory passed in rather than the notebook-level path_lr global.\n",
    "    bandw_list = ImageList.from_folder(source_dir)[:num_images]\n",
    "    results = []\n",
    "\n",
    "    if len(bandw_list.items) == 0: return results\n",
    "\n",
    "    for bandw_path in progress_bar(bandw_list.items):\n",
    "        target_path = target_dir/bandw_path.relative_to(source_dir)\n",
    "\n",
    "        try:\n",
    "            result_image = colorizer.get_transformed_image(path=bandw_path, render_factor=render_factor)\n",
    "            # Write below the filtered_dir argument instead of the path_results global.\n",
    "            result_path = Path(filtered_dir)/bandw_path.parent.name/bandw_path.name\n",
    "            result_path.parent.mkdir(parents=True, exist_ok=True)\n",
    "            result_image.save(result_path)\n",
    "            results.append((result_path, bandw_path, target_path))\n",
    "        except Exception as err:\n",
    "            # Best effort: a single bad image should not abort the whole benchmark run.\n",
    "            print('Failed to render image.  Skipping.  Details: {0}'.format(err))\n",
    "\n",
    "    return results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculate_fid_score(render_results, bs:int, eval_size:int):\n",
    "    \"\"\"Compute the Frechet distance between rendered images and their targets.\n",
    "\n",
    "    `render_results` holds `(rendered_path, bandw_path, target_path)` tuples; the\n",
    "    middle element is ignored. `bs` is the batch size used for both activation passes.\n",
    "    NOTE(review): `eval_size` is accepted but never used in this function.\n",
    "    \"\"\"\n",
    "    dims = 2048\n",
    "    cuda = True\n",
    "    inception = inception_model(dims=dims)\n",
    "\n",
    "    # Split the result tuples into two parallel lists of string paths.\n",
    "    triples = list(render_results)\n",
    "    rendered_paths = [str(rendered) for rendered, _, _ in triples]\n",
    "    target_paths = [str(target) for _, _, target in triples]\n",
    "\n",
    "    rendered_m, rendered_s = calculate_activation_statistics(\n",
    "        files=rendered_paths, model=inception, batch_size=bs, dims=dims, cuda=cuda)\n",
    "    target_m, target_s = calculate_activation_statistics(\n",
    "        files=target_paths, model=inception, batch_size=bs, dims=dims, cuda=cuda)\n",
    "    fid_score = calculate_frechet_distance(rendered_m, rendered_s, target_m, target_s)\n",
    "    del inception\n",
    "    return fid_score"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create black-and-white source images"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Only runs if the directory isn't already created."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate the grayscale inputs once; skipped entirely when path_lr already exists.\n",
    "if not path_lr.exists():\n",
    "    il = ImageList.from_folder(path_hr)\n",
    "    parallel(create_before_images, il.items)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ensure the parent of path_results exists; render_images creates the per-image output folders itself.\n",
    "path_results.parent.mkdir(parents=True, exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Rendering"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the colorizer with the artistic option enabled.\n",
    "colorizer = get_image_colorizer(artistic=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Colorize the grayscale set in path_lr; the originals in path_hr are recorded as the comparison targets.\n",
    "render_results = render_images(colorizer=colorizer, source_dir=path_lr, target_dir=path_hr, filtered_dir=path_results, render_factor=render_factor, num_images=num_images)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Colorization Scoring"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score the rendered images against their targets.\n",
    "fid_score = calculate_fid_score(render_results, bs=fid_batch_size, eval_size=eval_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Report the final score.\n",
    "print('FID Score: ' + str(fid_score))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
